import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def get_loss_function(loss_type, full_package):
    """Instantiate the criterion selected by ``loss_type``.

    The new criterion is stored under ``full_package["criterion"]`` and the
    same (mutated) dict is returned. ``full_package`` can carry settings for
    specific losses; currently only ``"taylorCE"`` reads from it
    (``full_package["taylor_factor"]``).

    Raises:
        ValueError: if ``loss_type`` is not one of the supported names.
    """
    # Ordered (name, factory) pairs, scanned in the same order as the former
    # if/elif chain. Factories are lambdas so a criterion (and any lookup into
    # ``full_package``) is only built when its name is selected.
    registry = (
        ("mse", lambda pkg: nn.MSELoss()),
        ("ce", lambda pkg: nn.CrossEntropyLoss()),
        ("bce", lambda pkg: nn.BCELoss()),
        ("ls", lambda pkg: smoothed_loss()),
        # ("nls", lambda pkg: nls()),  # disabled together with the nls class
        ("fdiv", lambda pkg: fdiv()),
        ("taylorCE", lambda pkg: taylorCE(pkg["taylor_factor"])),
        ("jocor", lambda pkg: jocor()),
        # TODO: What is the difference between CE and drops
        ("drops", lambda pkg: nn.CrossEntropyLoss()),
    )
    for name, build in registry:
        if loss_type == name:
            criterion = build(full_package)
            break
    else:
        raise ValueError(f"Unsupported loss type: {loss_type}")

    full_package["criterion"] = criterion
    return full_package


class spl(nn.Module):
    """Self-paced learning (SPL) loss with a hard keep/drop threshold.

    For epochs ``<= initial_epoch`` this is plain batch-mean cross-entropy.
    Afterwards only samples whose cross-entropy is ``<= alpha`` contribute, and
    the loss is averaged over the kept samples.

    Args:
        initial_epoch: last warm-up epoch during which every sample is used.
        alpha: selection threshold returned by ``f_spl_alpha_hard``; the
            default ``0.5`` is the previously hard-coded value.
    """

    def __init__(self, initial_epoch=20, alpha=0.5):
        super(spl, self).__init__()
        self.initial_epoch = initial_epoch
        self.alpha = alpha
        self.cross_entropy = nn.CrossEntropyLoss(reduction="none")

    def f_spl_alpha_hard(self, epoch):
        """Return the selection threshold for ``epoch``.

        Currently a constant schedule (``self.alpha``); override to make the
        threshold epoch-dependent.
        """
        return self.alpha

    def forward(self, y, t, epoch, **kwargs):
        """Return the scalar SPL loss for logits ``y`` and labels ``t``.

        Args:
            y: logits accepted by ``nn.CrossEntropyLoss``.
            t: class-index targets.
            epoch: current epoch, compared against ``initial_epoch``.
        """
        loss = self.cross_entropy(y, t)
        if epoch <= self.initial_epoch:
            # Warm-up: reduce to a scalar like the post-warm-up branch does
            # (previously the per-sample vector leaked out here).
            return torch.mean(loss)

        alpha = self.f_spl_alpha_hard(epoch)
        loss_v = (loss <= alpha).float()  # 1 for kept (easy) samples
        loss_ = loss_v * loss

        loss_v_sum = torch.sum(loss_v)
        if loss_v_sum.item() == 0:
            # Nothing kept: return a zero that stays attached to the graph.
            return torch.mean(loss_) / 1e8
        return torch.sum(loss_) / loss_v_sum


class cores(nn.Module):
    """Confidence-regularized loss with a loss-based sample sieve (CORES-style).

    Each sample's cross-entropy is reduced by ``beta * R``, where ``R`` is the
    class-mean of ``-log(softmax(y) + 1e-8)`` or, when ``noise_prior`` is set,
    the ``noise_prior``-weighted class sum of the same quantity. ``beta``
    follows the schedule built by ``init_beta``. For epochs
    ``<= epoch_threshold`` every sample is kept; afterwards a sample is kept
    only when its cross-entropy does not exceed its unweighted class-mean
    ``R``. The result is the adjusted loss averaged over kept samples.

    Args:
        num_epochs: training length used to size the ``betas`` schedule.
        beta_max: final regularization weight.
        noise_prior: optional per-class weights (tensor, array or list);
            presumably length ``num_classes`` -- TODO confirm with callers.
        epoch_threshold: last epoch during which the sieve keeps every sample
            (default 30, the previously hard-coded value).
    """

    def __init__(self, num_epochs, beta_max=2.0, noise_prior=None, epoch_threshold=30):
        super(cores, self).__init__()
        self.beta_max = beta_max
        self.noise_prior = noise_prior
        self.cross_entropy = nn.CrossEntropyLoss(reduction="none")
        self.epoch_threshold = epoch_threshold
        self.init_beta(num_epochs)

    def forward(self, y, t, epoch, **kwargs):
        """Return the scalar sieved, confidence-regularized loss."""
        # Epochs past the precomputed schedule reuse the final beta instead of
        # raising IndexError.
        beta = float(self.betas[min(epoch, len(self.betas) - 1)])
        loss = self.cross_entropy(y, t)
        loss_soft = -torch.log(F.softmax(y, dim=1) + 1e-8)
        mean_soft = torch.mean(loss_soft, 1)
        # The sieve criterion always uses the unweighted class mean.
        loss_sel = loss - mean_soft

        if self.noise_prior is None:
            adjusted_loss = loss - beta * mean_soft
        else:
            # Accept lists/arrays and follow the logits' device and dtype.
            prior = torch.as_tensor(
                self.noise_prior, dtype=loss_soft.dtype, device=loss_soft.device
            )
            adjusted_loss = loss - beta * torch.sum(prior * loss_soft, 1)

        if epoch <= self.epoch_threshold:
            loss_v = torch.ones_like(adjusted_loss)
        else:
            loss_v = (loss_sel <= 0.0).to(adjusted_loss.dtype)

        # Apply the mask
        loss_ = loss_v * adjusted_loss

        loss_v_sum = loss_v.sum()
        if loss_v_sum.item() == 0:
            # No sample survives: return a zero that stays attached to the graph.
            return torch.mean(loss_) / 1e8
        return torch.sum(loss_) / loss_v_sum

    def init_beta(self, num_epochs):
        """Build ``self.betas``: 10 zeros, a 30-step linear ramp to ``beta_max``,
        then ``beta_max`` for the remaining ``num_epochs - 40`` epochs.

        The plateau length is clamped at zero, so ``num_epochs < 40`` no longer
        makes ``np.linspace`` raise; ``betas`` then holds 40 entries.
        """
        beta1 = np.linspace(0.0, 0.0, num=10)
        beta2 = np.linspace(0.0, self.beta_max, num=30)
        beta3 = np.linspace(
            self.beta_max, self.beta_max, num=max(num_epochs - 10 - 30, 0)
        )
        self.betas = np.concatenate((beta1, beta2, beta3), axis=0)


class smoothed_loss(nn.Module):
    """Label-smoothed cross-entropy that switches on after a warm-up.

    For ``cur_e < starting_epoch`` the loss is the batch mean of plain
    cross-entropy. From ``starting_epoch`` on, each sample's loss becomes
    ``(1 - smooth_rate) * CE + smooth_rate * mean_j(-log(softmax(y)_j + 1e-8))``
    and the batch mean of that is returned.
    """

    def __init__(self, starting_epoch=10, smooth_rate=0.6):
        super(smoothed_loss, self).__init__()
        self.starting_epoch = starting_epoch
        self.smooth_rate = smooth_rate
        self.confidence = 1 - self.smooth_rate
        self.cross_entropy = nn.CrossEntropyLoss(reduction="none")

    def forward(self, cur_e, y, t, **kwargs):
        """Return the scalar loss for epoch ``cur_e``, 2-D logits ``y`` and 1-D labels ``t``."""
        assert y.shape[0] == t.shape[0], "y and t must have the same batch size"
        assert len(y.shape) == 2 and len(t.shape) == 1, "y must be 2D and t must be 1D"
        batch_size = y.shape[0]
        per_sample = self.cross_entropy(y, t)

        warming_up = cur_e < self.starting_epoch
        if not warming_up:
            uniform_term = (-torch.log(F.softmax(y, dim=1) + 1e-8)).mean(1)
            per_sample = self.confidence * per_sample + self.smooth_rate * uniform_term

        return per_sample.sum() / batch_size


# class nls(nn.Module):
#     def __init__(self, smooth_rate=-6.0):
#         super(nls, self).__init__()
#         self.smooth_rate = smooth_rate
#         self.confidence = 1 - self.smooth_rate
#         self.cross_entropy = nn.CrossEntropyLoss(reduction="none")

#     def forward(self, y, t, **kwargs):
#         assert y.shape[0] == t.shape[0], "y and t must have the same batch size"
#         assert len(y.shape) == 2 and len(t.shape) == 1, "y must be 2D and t must be 1D"

#         loss = self.cross_entropy(y, t)
#         loss_ = -torch.log(F.softmax(y, dim=1) + 1e-8)
#         loss = self.confidence * loss + self.smooth_rate * torch.mean(loss_, 1)
#         return torch.sum(loss) / y.shape[0]


class peer_loss(nn.Module):
    """Peer loss on top of cross-entropy.

    Returns ``CE(y, t) - alpha * CE(y_peer, t_peer)`` (batch means), where the
    peer term is computed on outputs paired with independently drawn labels
    (pairing is the caller's responsibility). Before ``starting_epoch`` only
    the plain cross-entropy term is used.

    Previously this class had no ``forward`` and could not be called; the
    argument order mirrors ``fdiv.forward`` in this module.

    Args:
        starting_epoch: first epoch at which the peer term is subtracted.
        alpha: weight of the peer term.
    """

    def __init__(self, starting_epoch, alpha):
        super(peer_loss, self).__init__()
        self.starting_epoch = starting_epoch
        self.alpha = alpha
        self.cross_entropy = nn.CrossEntropyLoss(reduction="none")

    def forward(self, y, y_peer, t, t_peer, epoch=None, **kwargs):
        """Return the scalar peer loss.

        Args:
            y, t: logits and labels of the current batch.
            y_peer, t_peer: logits and (independently sampled) labels for the
                peer term.
            epoch: current epoch; when ``None`` the peer term is always applied.
        """
        base = torch.mean(self.cross_entropy(y, t))
        if epoch is not None and epoch < self.starting_epoch:
            return base
        peer = torch.mean(self.cross_entropy(y_peer, t_peer))
        return base - self.alpha * peer


class fdiv(nn.Module):
    """Variational f-divergence loss between matched and peer (mismatched) pairs.

    ``forward`` picks ``y[i, t[i]]`` (via ``NLLLoss``, which only indexes and
    negates) for the matched batch and ``y_peer[i, t_peer[i]]`` for the peer
    batch, then returns ``activation(matched) - conjugate(peer)``. With both
    helpers equal to ``-mean(tanh(x) / 2)`` this appears to be the
    Total-Variation choice ``g(v) = tanh(v)/2, f*(u) = u`` -- NOTE(review):
    confirm against the intended divergence. Whether ``y`` holds logits or
    (log-)probabilities is not visible here -- TODO confirm with callers.
    """

    def __init__(self):
        super(fdiv, self).__init__()
        self.nill_loss = nn.NLLLoss(reduction="none")

    def activation(self, x):
        """Batch mean of ``-tanh(x)/2``."""
        return -torch.mean(torch.tanh(x) / 2.0)

    def conjugate(self, x):
        """Batch mean of ``-tanh(x)/2`` (same form as ``activation``)."""
        return -torch.mean(torch.tanh(x) / 2.0)

    def forward(self, y, y_peer, t, t_peer, **kwargs):
        """Return the scalar f-divergence loss."""
        assert (
            y.shape[0] == t.shape[0] and y_peer.shape[0] == t_peer.shape[0]
        ), "y and t must have the same batch size"
        assert (
            len(y.shape) == 2
            and len(t.shape) == 1
            and len(y_peer.shape) == 2
            and len(t_peer.shape) == 1
        ), "y must be 2D and t must be 1D"

        prob_acti = -self.nill_loss(y, t)
        prob_conj = -self.nill_loss(y_peer, t_peer)
        # activation/conjugate already average over the batch, so this is the
        # final scalar; the former extra `/ y.shape[0]` averaged twice.
        return self.activation(prob_acti) - self.conjugate(prob_conj)


class dmi(nn.Module):
    """Determinant-based mutual-information (DMI) loss.

    Forms the ``num_classes x num_classes`` matrix
    ``onehot(target)^T @ softmax(output)`` and returns ``-log(|det| + 0.001)``.
    """

    def __init__(self, num_classes=10):
        super(dmi, self).__init__()
        self.num_classes = num_classes

    def forward(self, output, target, **kwargs):
        """Return the scalar DMI loss for logits ``output`` and class indices ``target``."""
        probs = output.softmax(dim=1)
        labels = F.one_hot(target, num_classes=self.num_classes)
        labels = labels.to(dtype=torch.float, device=output.device)

        joint = torch.matmul(labels.t(), probs)
        # The determinant is taken in float32; the 0.001 offset keeps the log
        # finite when the matrix is singular.
        determinant = torch.det(joint.float())
        return -torch.log(determinant.abs() + 0.001)


class fw(nn.Module):
    """Forward loss correction with a label-transition matrix.

    The clean-class probabilities ``softmax(y)`` are mapped through
    ``trans_mat`` (``softmax(y) @ trans_mat``) and cross-entropy is computed
    against the observed labels. ``trans_mat`` is presumably row-stochastic
    with ``T[i, j] = P(observed j | clean i)`` -- TODO confirm with callers.
    """

    def __init__(self, num_classes=10):
        super(fw, self).__init__()
        self.num_classes = num_classes

    def forward(self, y, t, trans_mat, **kwargs):
        """Return the scalar forward-corrected cross-entropy.

        Args:
            y: logits, one row per sample.
            t: observed class indices.
            trans_mat: square transition matrix (tensor or array-like).
        """
        probs = F.softmax(y, dim=1)
        # Follow the logits' device/dtype instead of hard-coding .cuda(), which
        # failed without CUDA and ignored multi-GPU placement.
        trans_mat = torch.as_tensor(trans_mat).to(device=probs.device, dtype=probs.dtype)
        noisy_log_probs = torch.log(probs @ trans_mat)
        # cross_entropy re-applies log_softmax; when trans_mat is row-stochastic
        # the rows of probs @ trans_mat already sum to 1, so this equals the
        # NLL of the corrected probabilities.
        return F.cross_entropy(noisy_log_probs, t)


class bw(nn.Module):
    """Backward loss correction with a label-transition matrix.

    Per sample the loss is ``-sum_j Tinv[t, j] * log softmax(y)_j`` with
    ``Tinv = inverse(trans_mat)``, i.e. row ``t`` of ``Tinv`` applied to the
    vector of per-class log-losses; the batch mean is returned.
    """

    def __init__(self, num_classes=10):
        super(bw, self).__init__()
        self.num_classes = num_classes

    def forward(self, y, t, trans_mat, **kwargs):
        """Return the scalar backward-corrected loss.

        ``self`` was previously missing from this signature, so calling the
        module raised TypeError.

        Args:
            y: logits, one row per sample.
            t: observed class indices.
            trans_mat: square, invertible transition matrix (tensor or array-like).
        """
        # l_{backward}(y, h(x)) = [T^{-1} l(h(x))]_y
        # Invert on the logits' device (no hard-coded .cuda()), then match dtype.
        trans_mat = torch.as_tensor(trans_mat).to(device=y.device)
        trans_mat_inv = torch.inverse(trans_mat)
        log_probs = F.log_softmax(y, dim=1)  # stabler than log(softmax(y))
        corrected_targets = (
            F.one_hot(t, trans_mat.shape[0]).to(log_probs.dtype)
            @ trans_mat_inv.to(log_probs.dtype)
        )
        return -torch.mean(torch.sum(corrected_targets * log_probs, dim=1), dim=0)


class lq(nn.Module):
    """Generalized cross-entropy (L_q) loss: batch mean of ``(1 - p_t ** q) / q``.

    ``p_t`` is the softmax probability of the target class (plus ``1e-12``).
    """

    def __init__(self, q=0.7):
        super(lq, self).__init__()
        self.q = q

    def forward(self, y, t, **kwargs):
        """Return the scalar L_q loss for logits ``y`` and class indices ``t``."""
        q = self.q
        # The 1e-12 offset keeps probabilities strictly positive so p ** q has a
        # finite gradient.
        probs = F.softmax(y, dim=1) + 1e-12
        rows = torch.arange(probs.size(0))
        p_target = probs[rows, t]
        per_sample = (1.0 - p_target ** q) / q
        return per_sample.mean()


class adj(nn.Module):
    """Logit-adjusted cross-entropy.

    Class counts ``spc`` are normalized to a prior ``pi``; ``log(pi ** tau + 1e-12)``
    is added to every row of the logits before batch-mean cross-entropy.
    """

    def __init__(self, tau=1.0):
        super(adj, self).__init__()
        # Not used by forward; kept so existing attribute access still works.
        self.cross_entropy = nn.CrossEntropyLoss(reduction="none")
        self.tau = tau

    def forward(self, y, t, spc, **kwargs):
        """Return the scalar logit-adjusted loss.

        Args:
            y: logits, one row per sample.
            t: class indices.
            spc: per-class sample counts (tensor, array or list); not modified.
        """
        prior = torch.as_tensor(spc, dtype=torch.float32, device=y.device)
        # Out-of-place division: as_tensor can alias the caller's tensor/array,
        # and the former in-place `/=` silently normalized the caller's counts.
        prior = prior / torch.sum(prior)

        adjusted = y + torch.log(prior ** self.tau + 1e-12)
        return F.cross_entropy(adjusted, t)


class taylorCE(nn.Module):
    """Taylor cross-entropy: ``-log(p_t)`` truncated to ``n`` Taylor terms at ``p_t = 1``.

    Per sample the loss is ``sum_{i=1..n} (1 - p_t) ** i / i``, where ``p_t``
    is the softmax probability of the target class; the batch mean is
    returned. ``n = 0`` gives a zero loss.
    """

    def __init__(self, n=6):
        super(taylorCE, self).__init__()
        self.n = n

    def forward(self, y, t, **kwargs):
        """Return the scalar truncated-Taylor cross-entropy for logits ``y`` and labels ``t``."""
        probs = F.softmax(y, dim=1)
        p_target = probs.gather(1, t.unsqueeze(1)).squeeze(1)
        gap = 1.0 - p_target

        # Start from a tensor so n == 0 still supports .mean(); build the powers
        # incrementally instead of allocating a (-1) device tensor per term.
        loss = torch.zeros_like(p_target)
        term = torch.ones_like(p_target)
        for i in range(1, self.n + 1):
            term = term * gap  # (1 - p_t) ** i
            loss = loss + term / i
        return loss.mean()


class jocor(nn.Module):
    """JoCoR joint loss with small-loss sample selection.

    Combines the two networks' cross-entropies (weighted ``1 - co_lambda``)
    with a symmetric KL co-regularization (weighted ``co_lambda``), keeps the
    ``(1 - forget_rate)`` fraction of samples with the smallest combined loss,
    and averages over them.
    """

    def __init__(self):
        super(jocor, self).__init__()

    def kl_loss_compute(self, pred, soft_targets, reduce=True):
        """Per-sample ``KL(softmax(soft_targets) || softmax(pred))``, summed over classes.

        Returns the batch mean when ``reduce`` is true, else the per-sample vector.
        """
        kl = F.kl_div(
            F.log_softmax(pred, dim=1),
            F.softmax(soft_targets, dim=1),
            reduction="none",  # replaces the deprecated reduce=False
        )
        per_sample = torch.sum(kl, dim=1)
        return torch.mean(per_sample) if reduce else per_sample

    def forward(self, y_1, y_2, t, forget_rate, co_lambda=0.1):
        """Return ``(loss, loss)``: the same joint scalar for both networks.

        Args:
            y_1, y_2: logits from the two networks.
            t: class indices.
            forget_rate: fraction of the batch (highest joint loss) to drop.
            co_lambda: weight of the KL co-regularization term.
        """
        ce_weight = 1 - co_lambda
        loss_pick = (
            F.cross_entropy(y_1, t, reduction="none") * ce_weight
            + F.cross_entropy(y_2, t, reduction="none") * ce_weight
            + co_lambda * self.kl_loss_compute(y_1, y_2, reduce=False)
            + co_lambda * self.kl_loss_compute(y_2, y_1, reduce=False)
        )

        num_remember = int((1 - forget_rate) * loss_pick.numel())
        # Keep at least one sample so a large forget_rate cannot produce
        # mean(empty) = NaN.
        num_remember = max(num_remember, 1)

        # Rank on detached values, on the logits' device (no CPU/NumPy round-trip).
        ind_update = torch.argsort(loss_pick.detach())[:num_remember]

        # exchange: both networks are trained on the same selected joint loss
        loss = torch.mean(loss_pick[ind_update])
        return loss, loss


class FocalLoss(nn.Module):
    """Focal loss ``-alpha_t * (1 - p_t) ** gamma * log(p_t)`` on logits.

    Args:
        gamma: focusing parameter; ``0`` gives (alpha-weighted) cross-entropy.
        alpha: optional class weights. A scalar ``a`` becomes ``[a, 1 - a]``
            (two classes); a list or tuple is converted to a tensor; a tensor
            is used as-is.
        size_average: return the mean over elements if true, else the sum.
    """

    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha, (float, int)):
            self.alpha = torch.Tensor([alpha, 1 - alpha])
        elif isinstance(alpha, (list, tuple)):
            # Tuples were previously stored raw and crashed on `.type()` in forward.
            self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        """Return the focal loss for logits ``input`` and class indices ``target``.

        Inputs with more than two dims are treated as ``N, C, *`` and flattened
        to ``(N * prod(*), C)``; ``target`` is flattened to match.
        """
        if input.dim() > 2:
            input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1, 2)  # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1, input.size(2))  # N,H*W,C => N*H*W,C
        target = target.view(-1, 1)

        logpt = F.log_softmax(input, dim=1).gather(1, target).view(-1)
        # p_t is taken before alpha weighting, so alpha only scales the log term.
        pt = logpt.exp()

        if self.alpha is not None:
            if self.alpha.type() != input.data.type():
                # Cache alpha on the input's dtype/device for later calls.
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0, target.data.view(-1))
            logpt = logpt * at

        loss = -1 * (1 - pt) ** self.gamma * logpt
        if self.size_average:
            return loss.mean()
        return loss.sum()


class LDAMLoss(nn.Module):
    """Label-distribution-aware margin (LDAM) loss.

    Per-class margins are proportional to ``n_c ** -0.25`` and rescaled so the
    largest equals ``max_m``. Each sample's target-class logit is reduced by
    its class margin; the logits are then scaled by ``s`` and passed to
    cross-entropy.

    Args:
        cls_num_list: number of training samples per class.
        max_m: margin of the class with the fewest samples.
        weight: optional class weights forwarded to ``F.cross_entropy``.
        s: positive logit scale.
    """

    def __init__(self, cls_num_list, max_m=0.5, weight=None, s=30):
        super(LDAMLoss, self).__init__()
        m_list = 1.0 / np.sqrt(np.sqrt(cls_num_list))
        m_list = m_list * (max_m / np.max(m_list))
        # Stored on CPU and moved to the logits' device in forward, so the loss
        # no longer requires CUDA (was torch.cuda.FloatTensor).
        self.m_list = torch.tensor(m_list, dtype=torch.float32)
        assert s > 0
        self.s = s
        self.weight = weight

    def forward(self, x, target):
        """Return the scalar LDAM loss for logits ``x`` and class indices ``target``."""
        index = torch.zeros_like(x, dtype=torch.uint8)
        index.scatter_(1, target.data.view(-1, 1), 1)
        # torch.where needs a bool condition; uint8 is rejected by newer PyTorch.
        index = index.bool()

        margins = self.m_list.to(device=x.device, dtype=x.dtype)
        batch_m = margins[target].view(-1, 1)  # margin of each sample's class
        x_m = x - batch_m

        output = torch.where(index, x_m, x)
        return F.cross_entropy(self.s * output, target, weight=self.weight)


class BalancedSoftmax(nn.Module):
    """Balanced Softmax loss: cross-entropy on logits shifted by ``log(n_c)``.

    Args:
        sample_per_class: number of training samples per class.
        reduction: reduction passed to ``F.cross_entropy`` (default ``"mean"``).
    """

    def __init__(self, sample_per_class, reduction="mean"):
        super(BalancedSoftmax, self).__init__()
        # Kept off-GPU here (was torch.cuda.FloatTensor, which required CUDA);
        # forward moves it to the logits' device/dtype via type_as.
        self.sample_per_class = torch.as_tensor(sample_per_class, dtype=torch.float32).clone()
        self.reduction = reduction

    def forward(self, input, target):
        """Return the Balanced Softmax loss for logits ``input`` and class indices ``target``."""
        spc = self.sample_per_class.type_as(input)
        # Broadcasting adds log(n_c) to every row of the logits.
        logits = input + spc.log().unsqueeze(0)
        return F.cross_entropy(input=logits, target=target, reduction=self.reduction)
